graph: introduce internal dnnl_sdpa op #2930

Merged: 2 commits into main from xiang/main/internal-sdpa-part1, Mar 26, 2025

Contributor

@xiang1guo xiang1guo commented Mar 21, 2025

Description

  • This is the first part of the GQA pattern refinement requested by PyTorch upstream (MFDNN-12871). The PR adds a new internal dnnl_sdpa op for better fusion and alignment with the sdpa primitive. With it, internal compilation transforms the matched subgraph into a single sdpa op; see the compiled graphs before and after this PR below.
  • The refactor also aims to reduce graph compilation time: a simplified internal sdpa op cuts the time spent in layout propagation and memory planning.
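
To make the fused pattern concrete, here is a minimal reference sketch of what a float SDPA subgraph computes, assuming the usual single-head definition softmax(scale * Q * K^T + mask) * V. The names sdpa_reference and matrix_t are illustrative and not part of oneDNN; the sketch says nothing about how the dnnl_sdpa op or the GPU ukernel is actually implemented.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Illustrative row-major matrix type; not a oneDNN type.
using matrix_t = std::vector<std::vector<float>>;

// Reference single-head SDPA: softmax(scale * Q * K^T + mask) * V.
// q: [s_q x d], k: [s_kv x d], v: [s_kv x d_v], mask: [s_q x s_kv] or empty.
// Assumes non-empty inputs with consistent shapes.
matrix_t sdpa_reference(const matrix_t &q, const matrix_t &k,
        const matrix_t &v, float scale, const matrix_t &mask = {}) {
    const std::size_t s_q = q.size(), s_kv = k.size();
    const std::size_t d = q[0].size(), d_v = v[0].size();
    matrix_t out(s_q, std::vector<float>(d_v, 0.f));

    for (std::size_t i = 0; i < s_q; ++i) {
        // First matmul plus the binary ops (scale, optional additive mask).
        std::vector<float> score(s_kv, 0.f);
        float row_max = -std::numeric_limits<float>::infinity();
        for (std::size_t j = 0; j < s_kv; ++j) {
            for (std::size_t c = 0; c < d; ++c)
                score[j] += q[i][c] * k[j][c];
            score[j] *= scale;
            if (!mask.empty()) score[j] += mask[i][j];
            row_max = std::max(row_max, score[j]);
        }
        // Softmax over the row, shifted by the row maximum for stability.
        float denom = 0.f;
        for (float &s : score) {
            s = std::exp(s - row_max);
            denom += s;
        }
        // Second matmul: weighted sum of the rows of V.
        for (std::size_t j = 0; j < s_kv; ++j)
            for (std::size_t c = 0; c < d_v; ++c)
                out[i][c] += score[j] / denom * v[j][c];
    }
    return out;
}
```

In the decomposed graph each commented step is a separate internal op with its own intermediate tensor; a single fused op exposes only q, k, v, the optional scale and mask, and the output.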

Compiled graph before this PR

image

Compiled graph after this PR

image

Works

  • Added a dnnl_sdpa op. Currently it supports only float SDPA.
  • Added a new sdp_primitive_v1 kernel for simplicity. The final goal is to merge this kernel with the sdp_primitive kernel.
  • Moved the sdpa primitive ukernel creation into op_executable, which aligns it with the other kernels.

Follow-up

There will be another PR to refine the GQA pattern based on this new internal dnnl_sdpa.

Validation

Correctness check

A total of 66 cases can be supported by the GPU ukernel. The 42 float SDPA cases now run through sdp_primitive_v1_kernel_t; the other 24 cases are quantized SDPA and all run through sdp_primitive_kernel_t.

Kernel                    | PR | main branch
sdp_primitive_v1_kernel_t | 42 | 0
sdp_primitive_kernel_t    | 24 | 66
larger_partition_kernel_t | 76 | 76
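
The split in the table can be read as a simple classification. The sketch below is a hypothetical model of it, not the backend's dispatch code: it assumes a case is routed only by whether the GPU ukernel supports it and whether it is quantized, and that the 76 remaining cases are the ones the ukernel does not cover. sdpa_case_t, select_kernel, and count_by_kernel are made-up names.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Kernel names mirror the table; the enum itself is illustrative.
enum class kernel_kind_t { sdp_primitive_v1, sdp_primitive, larger_partition };

// Hypothetical description of one test case.
struct sdpa_case_t {
    bool ukernel_supported; // can the GPU ukernel handle this case?
    bool quantized; // quantized SDPA (true) or float SDPA (false)
};

// Assumed routing after this PR: float cases go to the v1 kernel,
// quantized cases stay on the legacy kernel, everything else falls back.
kernel_kind_t select_kernel(const sdpa_case_t &c) {
    if (!c.ukernel_supported) return kernel_kind_t::larger_partition;
    return c.quantized ? kernel_kind_t::sdp_primitive
                       : kernel_kind_t::sdp_primitive_v1;
}

// Tally how many cases land on each kernel.
std::map<kernel_kind_t, std::size_t> count_by_kernel(
        const std::vector<sdpa_case_t> &cases) {
    std::map<kernel_kind_t, std::size_t> counts;
    for (const auto &c : cases)
        ++counts[select_kernel(c)];
    return counts;
}
```

Feeding it 42 float and 24 quantized ukernel-supported cases plus 76 unsupported ones reproduces the PR column; on the main branch all 66 supported cases would count toward sdp_primitive_kernel_t.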

Performance test

  • SDPA partition-level performance is comparable to the main branch.
  • Graph compilation time of the v1 kernel (excluding primitive creation time) shows a 2x speedup compared to the legacy sdp_primitive_kernel.

@github-actions github-actions bot added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Mar 21, 2025
@xiang1guo xiang1guo self-assigned this Mar 21, 2025
@xiang1guo xiang1guo force-pushed the xiang/main/internal-sdpa-part1 branch from a2f5367 to 6145d95 Compare March 21, 2025 14:58
@xiang1guo xiang1guo marked this pull request as ready for review March 24, 2025 01:37
@xiang1guo xiang1guo requested a review from a team as a code owner March 24, 2025 01:37
@rongzha1
Contributor

Graph compilation time of v1 kernel(exclude primitive creation time) got 2x speedup compared to legacy sdp_primitive_kernel.

What were the main optimizations that led to the 2x performance gain?

@xiang1guo
Contributor Author

Graph compilation time of v1 kernel(exclude primitive creation time) got 2x speedup compared to legacy sdp_primitive_kernel.

What were the main optimizations that led to the 2x performance gain?

First, the compilation bottleneck is the layout_propagation and memory_planning passes. For each op, layout propagation tries to create a primitive descriptor (pd) to get the optimal layout. Memory planning walks the whole graph to plan the input, output, and internal memory sizes for the subgraph, which creates many memory descriptors (md).
Second, as the attached subgraphs before and after the change show, we fused several internal ops (matmul/softmax/binary) into one sdpa op. For layout propagation, this skips most of the propagation steps and reduces pd creation time. For memory planning, there is no internal temporary memory left to plan, so no additional md creation is needed.
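
A toy model of that argument, for illustration only: it assumes layout propagation costs one pd query per op and memory planning creates one md per tensor that lives only inside the subgraph. The op list for the decomposed graph is a guess based on the "matmul/softmax/binary" wording above, and none of these types exist in oneDNN.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy op: a kind plus the number of outputs consumed only inside the subgraph.
struct toy_op_t {
    std::string kind;
    std::size_t num_internal_outputs;
};

// Work counters for the two passes named in the comment above.
struct compile_work_t {
    std::size_t pd_queries; // layout propagation: one pd query per op (assumed)
    std::size_t internal_mds; // memory planning: one md per internal tensor (assumed)
};

compile_work_t estimate_compile_work(const std::vector<toy_op_t> &ops) {
    compile_work_t work {0, 0};
    for (const auto &op : ops) {
        work.pd_queries += 1;
        work.internal_mds += op.num_internal_outputs;
    }
    return work;
}

// Assumed decomposed form: matmul -> scale -> mask -> softmax -> matmul.
std::vector<toy_op_t> decomposed_sdpa() {
    return {{"matmul", 1}, {"binary(scale)", 1}, {"binary(mask)", 1},
            {"softmax", 1}, {"matmul", 0}};
}

// Fused form: one op, no internal tensors visible to the planner.
std::vector<toy_op_t> fused_sdpa() {
    return {{"sdpa", 0}};
}
```

Under these assumptions the decomposed graph needs 5 pd queries and 4 internal mds, while the fused graph needs 1 and 0. The real passes do more than this, so the counts indicate only the direction of the saving, not the measured speedup.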

@xiang1guo
Contributor Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@xiang1guo xiang1guo force-pushed the xiang/main/internal-sdpa-part1 branch from 4560e25 to e0ded0f Compare March 25, 2025 03:45
@xiang1guo
Contributor Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@TaoLv
Contributor

TaoLv commented Mar 25, 2025

@xiang1guo please fix the clang tidy warnings:

clang-tidy warnings

Analyzing src/graph/backend/dnnl/dnnl_shape_infer.cpp
2264 warnings generated.
Suppressed 2395 warnings (2264 in non-user code, 131 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/large_partition.cpp
3008 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
2647 | fusion_info_mgr_t &mgr, pd_cache_t &pd_cache) {
| : with_scale_(op->get_attr(op_attr::with_scale))
2648 |
2649 | auto md_q = make_dnnl_memory_desc(
2650 | op->get_input_value(0)->get_logical_tensor());
2651 | auto md_k = make_dnnl_memory_desc(
2652 | op->get_input_value(1)->get_logical_tensor());
2653 | auto md_v = make_dnnl_memory_desc(
2654 | op->get_input_value(2)->get_logical_tensor());
2655 | auto md_dst = make_dnnl_memory_desc(
2656 | op->get_output_value(0)->get_logical_tensor());
2657 |
2658 | auto scale_dt = impl::data_type::undef;
2659 | size_t idx = 3;
2660 | with_scale_ = op->get_attr(op_attr::with_scale);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
2718 | return;
| ^~~~~~~
2719 | }
Suppressed 3155 warnings (3006 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/matmul.cpp
3002 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3149 warnings (3000 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/mqa_decomp.cpp
3038 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3186 warnings (3036 in non-user code, 150 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/sdp_decomp.cpp
3099 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3247 warnings (3097 in non-user code, 150 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/sdp_primitive.cpp
2992 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3142 warnings (2990 in non-user code, 152 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
2739 warnings generated.
Suppressed 2888 warnings (2739 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp
2992 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3142 warnings (2990 in non-user code, 152 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/layout_propagator.cpp
2876 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3023 warnings (2874 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/op_executable.cpp
3110 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3257 warnings (3108 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/passes/compile_ops.cpp
2878 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3025 warnings (2876 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.

Analyzing src/graph/backend/dnnl/passes/transform.cpp
2937 warnings generated.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
Suppressed 3084 warnings (2935 in non-user code, 149 NOLINT).
Use -header-filter=.* to display errors from all non-system headers. Use -system-headers to display errors from system headers as well.
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2660:9: warning: 'with_scale_' should be initialized in a member initializer of the constructor [cppcoreguidelines-prefer-member-initializer]
/home/runner/work/oneDNN/oneDNN/src/graph/backend/dnnl/op_executable.hpp:2718:9: warning: redundant return statement at the end of a function with a void return type [readability-redundant-control-flow]
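
For reference, the two reported checks can be illustrated in isolation as follows. This is a minimal sketch, not the code from op_executable.hpp: toy_op_t, executable_before_t, and executable_after_t are hypothetical stand-ins, and the attribute lookup is reduced to a plain getter.

```cpp
#include <cassert>

// Hypothetical stand-in for the graph op; only the attribute lookup matters.
struct toy_op_t {
    bool with_scale;
    bool get_with_scale() const { return with_scale; }
};

// Shape of the code that triggers both warnings.
class executable_before_t {
public:
    explicit executable_before_t(const toy_op_t &op) {
        // cppcoreguidelines-prefer-member-initializer: member assigned in
        // the constructor body instead of the member initializer list.
        with_scale_ = op.get_with_scale();
    }
    void reset() {
        with_scale_ = false;
        return; // readability-redundant-control-flow: void function ends here anyway.
    }
    bool with_scale() const { return with_scale_; }

private:
    bool with_scale_ = false;
};

// Same behavior with both warnings addressed.
class executable_after_t {
public:
    explicit executable_after_t(const toy_op_t &op)
        : with_scale_(op.get_with_scale()) {}
    void reset() { with_scale_ = false; } // no trailing return
    bool with_scale() const { return with_scale_; }

private:
    bool with_scale_;
};
```

Both classes behave identically; the second form only moves the initialization into the initializer list, as the fix-it hint in the log suggests, and drops the trailing return.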

@xiang1guo xiang1guo force-pushed the xiang/main/internal-sdpa-part1 branch from e0ded0f to 71905d2 Compare March 25, 2025 12:14
@xiang1guo
Contributor Author

xiang1guo commented Mar 25, 2025

@xiang1guo please fix the clang tidy warnings:
clang-tidy warnings

Thanks for the reminder. I fixed the warnings and reorganized the commits; please check again. The remaining clang-tidy warnings do not seem to be related to the changes in this PR.

@xiang1guo
Contributor Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@TaoLv TaoLv merged commit 6bc6597 into main Mar 26, 2025
12 of 13 checks passed
@TaoLv TaoLv deleted the xiang/main/internal-sdpa-part1 branch March 26, 2025 14:27